# -*- coding: utf-8 -*-
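"""KG_embedding_vanilla.ipynb

Colab notebook export. It trains a character-level, single-attention-block
generator that reads a CLEVRER question and emits its correct-choice programs.
Each question's token embeddings are augmented with a knowledge-graph (KG)
embedding learned from the triples describing the question's video.

The file contains IPython shell escapes (`!pip ...`), so it runs inside a
notebook, not as a plain Python script.
"""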

# Colab-only setup: this import only works inside a Google Colab runtime.
from google.colab import drive

# Mount Google Drive at /content/gdrive. The UNIT_TEST section reads its KG file
# from "gdrive/MyDrive/..." relative to the working directory (presumably
# /content in Colab).
drive.mount("/content/gdrive")

# IPython shell escape (not plain-Python syntax): rdflib parses the
# Turtle-format knowledge graph.
!pip install rdflib

# IPython shell escape: ampligraph provides the KG embedding model
# (ScoringBasedEmbeddingModel).
!pip install ampligraph

# Leftover markdown-cell heading from the notebook export; a no-op string expression.
"""#Main program starts here"""

from random import sample, shuffle, random
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from rdflib import Graph as RDFGraph
import numpy as np
from ampligraph.latent_features import ScoringBasedEmbeddingModel
import gc

#define knowledge graph utilities class
class CLEVRER_KG_Utils(object):
  """Static helpers for reading the CLEVRER knowledge graph (KG).

  Predicates are matched by substring (e.g. ``'hasQuestion' in p``), so any
  predicate IRI containing that local name matches, whatever its namespace.
  The triple-scanning helpers only iterate ``kg`` as ``(subject, predicate,
  object)`` tuples, so any iterable of string 3-tuples works as well as an
  rdflib Graph.
  """

  @staticmethod
  def parseKG_from_file(kg_file):
    """Parse a Turtle-format KG file with rdflib and return the Graph."""
    kg = RDFGraph()
    kg.parse(kg_file, format='turtle')
    return kg

  @staticmethod
  def get_questions(kg):
    """Map every question in ``kg`` to its question string.

    A question is the object of a ``hasQuestion`` triple. Its text is the
    object of a ``hasStringValue`` triple whose subject is that question.
    Questions without a string value are omitted. If a question has several
    string values, the one seen last while iterating ``kg`` wins.

    Returns:
      dict question -> question string, in first-seen order. It is empty if
      ``kg`` contains no questions.
    """
    print ('processing kg triples ... ') #print statement for clarity in output

    # One pass indexes string values by subject. This replaces a full rescan of
    # the graph for every question (O(questions * triples)). Later triples
    # overwrite earlier ones, preserving "last match wins".
    string_values = {}
    question_iris = []
    for s, p, o in tqdm(kg):
      if 'hasStringValue' in p:
        string_values[s] = o
      if 'hasQuestion' in p:
        question_iris.append(o)

    kg_questions = {}
    for question_iri in question_iris:
      if question_iri in string_values:
        kg_questions[question_iri] = string_values[question_iri]
    return kg_questions

  @staticmethod
  def get_correct_choices(kg,kg_questions):
    """Collect the programs of every correct choice of each question.

    A choice of question ``q`` is the object of a ``(q, hasChoices, choice)``
    triple. The choice counts as correct once for every
    ``(choice, hasStringValue, o)`` triple whose ``o`` contains ``'correct'``.
    Its programs are the objects of its ``hasProgram`` triples.

    NOTE(review): ``'correct' in o`` is a substring test, so a value such as
    ``'incorrect'`` would also match. Confirm the KG only uses values like
    'correct'/'wrong'.

    Args:
      kg: iterable of (subject, predicate, object) triples.
      kg_questions: dict (or iterable) whose keys are the question terms.

    Returns:
      dict question -> list of program strings. Every question gets an entry,
      possibly empty. Programs appear in choice order, then ``kg`` order, and
      are repeated once per matching 'correct' triple, as before.
    """
    # Index everything in a single pass instead of rescanning the graph per question.
    choices_of = {}     # subject -> [objects of hasChoices triples], kg order
    correct_marks = {}  # subject -> number of hasStringValue triples containing 'correct'
    programs_of = {}    # subject -> [str(objects of hasProgram triples)], kg order
    for s, p, o in kg:
      if 'hasChoices' in p:
        choices_of.setdefault(s, []).append(o)
      if 'hasStringValue' in p and 'correct' in o:
        correct_marks[s] = correct_marks.get(s, 0) + 1
      if 'hasProgram' in p:
        programs_of.setdefault(s, []).append(str(o))

    correct_choices = {}
    for question in kg_questions:
      programs = []
      for choice in choices_of.get(question, []):
        # A choice marked correct k times contributes its programs k times.
        programs.extend(programs_of.get(choice, []) * correct_marks.get(choice, 0))
      correct_choices[question] = programs
    return correct_choices

  @staticmethod
  def get_knowledge(question,kg):
    """Collect the triples describing the video that ``question`` belongs to.

    The video id is the part of the question term before ``'__'``, after its
    last ``'#'`` (e.g. ``'ns#vid1__q1'`` -> ``'vid1'``). A triple is kept when
    its subject contains that id plus one of ``'Scene'``, ``'Observation'`` or
    ``'Point'``, and it is not one of the skipped kinds below.

    NOTE(review): ``video_id in s`` is a substring test, so an id such as
    ``'video_1'`` would also match subjects of ``'video_10'``. Confirm the IRI
    scheme rules this out.

    Returns:
      list of ``[subject, predicate, object]`` string triples, in ``kg`` order.
    """
    video_id = question.split('__')[0].split('#')[-1]

    relevant_knowledge = []
    for s, p, o in kg:
      if 'hasStringValue' in p or 'hasProgram' in p:
        continue  # skip string values and program values
      if 'hasQuestion' in p and o != question:
        continue  # keep only the link to this question, not to other questions
      if 'hasChoices' in p:
        continue  # skip question choices
      if 'type' in p and ('owl' in o or 'trafficmonitoring' in o):
        continue  # skip OWL types and parent-ontology types
      if video_id in s and ('Scene' in s or 'Observation' in s or 'Point' in s):
        relevant_knowledge.append([str(s), str(p), str(o)])
    return relevant_knowledge

  @staticmethod
  def get_question_embedding(question,kg,emb_size = 5,method = 'TransE',epochs = 5):
    """Train a fresh KG embedding model on the question's video knowledge and
    return the learned embedding of the question entity.

    Args:
      question: question term. Presumably it must occur in the extracted
        triples (e.g. via its ``hasQuestion`` link) for the lookup to succeed.
      kg: iterable of (subject, predicate, object) triples.
      emb_size: embedding size ``k`` of the model.
      method: ampligraph ``scoring_type`` (e.g. ``'TransE'``).
      epochs: training epochs for the embedding model.

    Returns:
      The result of ``model.get_embeddings([question], embedding_type='e')``.
      This is presumably a numpy array of shape (1, emb_size) for TransE.
      TODO confirm.

    Raises:
      ValueError: if no relevant triples were found for ``question``.
    """
    rkg = CLEVRER_KG_Utils.get_knowledge(question,kg)
    if not rkg:
      # Fail with a clear message instead of handing an empty array to model.fit.
      raise ValueError('no relevant KG triples found for question %s' % question)

    # k = embedding size; scoring_type selects the model (TransE, ...).
    model = ScoringBasedEmbeddingModel(k=emb_size, eta=1, scoring_type=method)
    model.compile(optimizer='adam', loss='nll') # nll --> negative log likelihood
    model.fit(np.array(rkg), epochs=epochs)
    return_value = model.get_embeddings([str(question)],embedding_type='e')
    # A new model is built on every call; release it eagerly.
    del model; del rkg; gc.collect()
    return return_value

#define tokenizer
class Tokenizer(object):
  """Character-level tokenizer over question strings and correct-choice programs.

  The vocabulary is every character that appears in a question string or in
  its formatted choice text ``'CH:' + '^_^<E>^_^'.join(programs) + '^_^<E>'``.
  This is the same format ``Dataloader.get_batch`` encodes.
  """

  def __init__(self,
               kg_questions,
               kg_choices):
    """Build the vocabulary.

    Args:
      kg_questions: dict question -> question string.
      kg_choices: dict question -> list of correct-choice program strings. It
        must have an entry for every key of ``kg_questions``.
    """
    chars = set()

    print ('Initializing tokenizer ... ')
    for question in tqdm(kg_questions):
      chars.update(kg_questions[question])  # question characters
      # 'CH:' separates the question from its choices; '<E>' ends each program.
      formatted_choices = 'CH:'+('^_^<E>^_^'.join(kg_choices[question]))+'^_^<E>'
      chars.update(formatted_choices)

    # Sort the characters rather than using raw set order. Set order varies
    # between interpreter runs because str hashing is randomized, so sorting
    # keeps token ids reproducible.
    self.chars = sorted(chars)
    self.vocab_size = len(self.chars)
    self.char_index = {char: i for i, char in enumerate(self.chars)}  # char -> token id

  def encode(self,
             string):
    """Return the token id of each character of ``string``.

    Raises:
      KeyError: if ``string`` contains a character outside the vocabulary.
    """
    return [self.char_index[char] for char in string]

  def decode(self,
             encoding):
    """Return the string spelled by the token ids in ``encoding``.

    Raises:
      IndexError: if an id is outside ``range(vocab_size)``.
    """
    return ''.join([self.chars[i] for i in encoding])

#define dataloader class
class Dataloader(object):
  """Turns questions and their correct-choice programs into next-token examples."""

  def __init__(self,
               tokenizer):
    """Store ``tokenizer``; it needs an ``encode(str) -> list[int]`` method."""
    self.tokenizer = tokenizer

  def get_batch(self,
                kg_questions,
                kg_answers,
                n = None):
    """Rebuild ``self.data`` from all questions (visited in shuffled order).

    Each question is encoded together with its formatted choices
    ``'CH:' + '^_^<E>^_^'.join(programs) + '^_^<E>'``. Every prefix of that
    sequence then becomes one example ``[(question, prefix), next_token]``.

    Args:
      kg_questions: dict question -> question string.
      kg_answers: dict question -> list of correct-choice program strings.
      n: if given, keep a random sample of ``n`` examples (``random.sample``
        raises ValueError when ``n`` exceeds the number of examples).
    """
    question_batch = list(kg_questions.keys())
    shuffle(question_batch)

    data = []
    for question in question_batch:
      question_encoding = self.tokenizer.encode(kg_questions[question])
      # 'CH:' separates the question from its choices; '<E>' ends each program.
      formatted_choices = 'CH:'+('^_^<E>^_^'.join(kg_answers[question]))+'^_^<E>'
      # Concatenate both parts because the model learns to generate both.
      full_encoding = question_encoding + self.tokenizer.encode(formatted_choices)
      for t in range(1, len(full_encoding)):
        data.append([(question, full_encoding[:t]), full_encoding[t]])

    self.data = data if n is None else sample(data, n)

#define class for general tensor utilities
class Tensor_Utils(object):
  """General tensor helpers."""

  @staticmethod
  def normed(T):
    """Return ``T`` divided by the 2-norm of all its elements (flattened).

    The norm is taken as a Python float (``.item()``), so it is treated as a
    constant by autograd. An all-zero tensor has norm 0; a copy of it is
    returned instead of the NaNs a 0/0 division would produce.
    """
    norm = torch.linalg.norm(T).item()
    if norm == 0.0:
      return T.clone()
    return torch.div(T, norm)

#define generator class that will input question,
#and generate the correct choice programs
class Generator(nn.Module):
  """Single self-attention-block model predicting the next character of a
  question + correct-choice-program sequence.

  On every forward pass it also trains a small KG embedding model on the
  question's video knowledge (``CLEVRER_KG_Utils.get_question_embedding``).
  The resulting question vector is added to every token embedding.
  """

  def __init__(self,
               vocab_size = None,
               emb_size = None,
               context_size = None,
               n_heads = None,
               kg = None):
    """
    Args:
      vocab_size: number of tokens (e.g. ``Tokenizer.vocab_size``).
      emb_size: embedding width, also used as the KG embedding size. It must be
        divisible by ``n_heads`` (``nn.MultiheadAttention`` requirement).
      context_size: maximum number of most-recent tokens the model reads.
      n_heads: number of attention heads.
      kg: triple iterable passed on to ``CLEVRER_KG_Utils``.
    """
    super().__init__()

    # store config data
    self.vocab_size = vocab_size
    self.emb_size = emb_size
    self.context_size = context_size
    self.n_heads = n_heads
    self.kg = kg

    # token embeddings plus learned position embeddings
    self.embeddings = nn.Embedding(self.vocab_size, self.emb_size)
    self.pos_embeddings = nn.Embedding(self.context_size, self.emb_size)

    # query, key, value projections feeding the multi-head self-attention
    self.query = nn.Linear(self.emb_size,self.emb_size,bias=False)
    self.key = nn.Linear(self.emb_size,self.emb_size,bias=False)
    self.value = nn.Linear(self.emb_size,self.emb_size,bias=False)
    self.multihead_attn = nn.MultiheadAttention(self.emb_size,self.n_heads)

    # classification head over the vocabulary
    self.head = nn.Linear(self.emb_size,self.vocab_size)
    self.attn_output_weights = None  # attention weights of the latest forward pass

  def forward(self,
              question_data):
    """Return next-token logits for one data point.

    Args:
      question_data: a ``Dataloader.data`` item ``[(question, encoding), target]``.
        Only ``question_data[0]`` is read, and at most the last
        ``context_size`` tokens of ``encoding`` are used.

    Returns:
      1-D tensor of ``vocab_size`` logits for the token following ``encoding``.
    """
    question_encoding = question_data[0][1][-self.context_size:]  # most recent tokens only
    question = question_data[0][0]

    # A fresh KG embedding model is trained on every call (epochs=1). Its output
    # is numpy, so it enters the network as a constant without gradient.
    question_knowledge_vector = CLEVRER_KG_Utils.get_question_embedding(question,
                                                                        self.kg,
                                                                        emb_size = self.emb_size,
                                                                        method='TransE',
                                                                        epochs = 1)
    question_knowledge_vector = torch.from_numpy(question_knowledge_vector).float()

    question_tensor = torch.tensor(question_encoding)
    n_tokens = len(question_tensor)
    # The knowledge vector is broadcast over all token positions; this assumes
    # shape (1, emb_size) or (emb_size,). TODO confirm.
    question_embedding = self.embeddings(question_tensor) + self.pos_embeddings(torch.arange(n_tokens))
    question_embedding = question_embedding + question_knowledge_vector

    Q = self.query(question_embedding)
    K = self.key(question_embedding)
    V = self.value(question_embedding)
    question_embedding, self.attn_output_weights = self.multihead_attn(Q,K,V)

    # Logits of the last token position (last row), i.e. the next-token prediction.
    # NOTE(review): leaky_relu on logits is unusual; kept as in the original design.
    logits = F.leaky_relu(self.head(question_embedding))[-1]
    return logits

  def train(self,
            kg_questions = True,
            kg_choices = None,
            dataloader_object = None,
            batch_size = 32,
            epochs = 100,
            mode = None):
    """Run the training loop, or switch train/eval mode like ``nn.Module.train``.

    ``nn.Module.eval()`` is implemented as ``self.train(False)``. This override
    must therefore still accept a single boolean. ``g.train()``,
    ``g.train(False)``, ``g.train(mode=False)`` and ``g.eval()`` behave exactly
    like the ``nn.Module`` versions and return ``self``.

    Args:
      kg_questions: dict question -> question string, or a bool to only set
        the training mode.
      kg_choices: dict question -> list of correct-choice programs.
      dataloader_object: object whose ``get_batch(kg_questions, kg_choices, n=...)``
        fills ``.data`` with ``[(question, prefix), next_token]`` items.
      batch_size: examples sampled per optimization step.
      epochs: number of optimization steps (one sampled batch each).
      mode: keyword form of the ``nn.Module.train`` mode flag.

    Raises:
      ValueError: if training is requested without a ``dataloader_object``.
    """
    if mode is not None:
      return super().train(mode)
    if isinstance(kg_questions, bool):
      return super().train(kg_questions)
    if dataloader_object is None:
      raise ValueError('dataloader_object is required for training')

    super().train(True)  # put all submodules in training mode
    dl = dataloader_object
    optimizer = torch.optim.AdamW(self.parameters())
    loss_fn = F.cross_entropy

    print ('Starting training loop ... ')
    for _ in tqdm(range(epochs)):
      dl.get_batch(kg_questions, kg_choices, n = batch_size)
      n_batch = len(dl.data)
      if n_batch == 0:
        continue  # nothing to learn from; avoids dividing by zero

      batch_loss = 0.0
      for data_point in dl.data:
        y = data_point[-1]          # next-token id
        logits = self(data_point)   # forward reads data_point[0]
        targets = F.one_hot(torch.tensor(y), self.vocab_size).float()
        batch_loss = batch_loss + loss_fn(logits, targets)

      batch_loss = batch_loss / n_batch  # average loss over the batch
      print ('batch loss: ',batch_loss.item())  # monitor convergence
      batch_loss.backward()
      nn.utils.clip_grad_norm_(self.parameters(), 1.0)
      optimizer.step()
      optimizer.zero_grad()
      del dl.data
    gc.collect()

# Set to True to run an end-to-end smoke test: parse a small KG file from the
# mounted Google Drive, build the tokenizer and dataloader, and train a Generator.
UNIT_TEST = False
if UNIT_TEST:
  kg = CLEVRER_KG_Utils.parseKG_from_file("gdrive/MyDrive/AGNNs/small_copy.ttl")
  kg_questions = CLEVRER_KG_Utils.get_questions(kg)
  kg_choices = CLEVRER_KG_Utils.get_correct_choices(kg,kg_questions)
  t = Tokenizer(kg_questions,kg_choices)
  dl = Dataloader(t)
  # emb_size (96) must be divisible by n_heads (12) for nn.MultiheadAttention.
  # emb_size is also the KG embedding size used inside Generator.forward.
  g = Generator(vocab_size=t.vocab_size,
                emb_size = 96,
                context_size = 100,
                n_heads = 12,
                kg = kg)
  # batch_size=1: one sampled example per optimization step, for 100 steps.
  g.train(kg_questions,kg_choices,dl,batch_size = 1,epochs=100)
  # Free the large objects (parsed graph, model) once the smoke test finishes.
  del kg; del kg_questions; del kg_choices; del t; del dl; del g; gc.collect()
